{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "81949891-c63f-40b4-ba15-0e558e58e618",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "from jax import grad, jit, vmap\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "from sklearn.preprocessing import StandardScaler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c964d436-96d4-44a5-9816-b5c93c5a7ae1",
   "metadata": {},
   "outputs": [],
   "source": [
    "PREDICTORS = [\"tmax\", \"tmin\", \"rain\"]\n",
    "TARGET = \"tmax_tomorrow\"\n",
    "\n",
    "# Load the cleaned weather table and forward-fill missing values.\n",
    "data = pd.read_csv(\"../../data/clean_weather.csv\", index_col=0)\n",
    "data = data.ffill()\n",
    "\n",
    "# Row positions of the 70% / 15% / 15% train / validation / test boundaries.\n",
    "train_end = int(.7 * len(data))\n",
    "valid_end = int(.85 * len(data))\n",
    "\n",
    "# Fit the scaler on the training rows only, so that validation/test statistics\n",
    "# do not leak into the features the model is trained on; then apply the same\n",
    "# transform to every row.\n",
    "scaler = StandardScaler()\n",
    "scaler.fit(data.iloc[:train_end][PREDICTORS])\n",
    "data[PREDICTORS] = scaler.transform(data[PREDICTORS])\n",
    "\n",
    "# Slice positionally rather than calling np.split on a DataFrame, which relies\n",
    "# on the deprecated DataFrame.swapaxes.\n",
    "split_data = [data.iloc[:train_end], data.iloc[train_end:valid_end], data.iloc[valid_end:]]\n",
    "(train_x, train_y), (valid_x, valid_y), (test_x, test_y) = [[d[PREDICTORS].to_numpy(), d[[TARGET]].to_numpy()] for d in split_data]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0684855a-6429-4973-a245-1727bacbfa22",
   "metadata": {},
   "outputs": [],
   "source": [
    "def init_layers(inputs):\n",
    "    layers = []\n",
    "    for i in range(1, len(inputs)):\n",
    "        layers.append([\n",
    "            np.random.rand(inputs[i-1], inputs[i]) / 5 - .1,\n",
    "            np.ones((1,inputs[i]))\n",
    "        ])\n",
    "    return layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "b281248a-dbed-403b-b381-915a2a995aac",
   "metadata": {},
   "outputs": [],
   "source": [
    "def forward(layers, x):\n",
    "    for i in range(len(layers)):\n",
    "        x = jnp.matmul(x, layers[i][0]) + layers[i][1]\n",
    "        if i < len(layers) - 1:\n",
    "            x = jnp.maximum(x, 0)\n",
    "    return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "40a899d9-97ef-4a98-93e4-b64919ea1991",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mse(y, preds):\n",
    "    return jnp.mean((y - preds)**2)\n",
    "\n",
    "def loss(layers, x, y):\n",
    "    \"\"\"Return the mean squared error of the network's predictions for x against y.\"\"\"\n",
    "    return mse(y, forward(layers, x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "5db618bf-b2ca-4e04-99d7-d106a248c4cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "@jit\n",
    "def backward(layers, x, y):\n",
    "    \"\"\"Apply one gradient-descent step and return the updated layers.\n",
    "\n",
    "    The gradient of `loss` with respect to `layers` is evaluated on the batch\n",
    "    (x, y). Weights move against their gradient plus a 0.01 * weight penalty\n",
    "    term; biases move against their gradient only. Both are scaled by `lr`.\n",
    "\n",
    "    NOTE: `lr` is not a parameter; it is read from the enclosing scope. Because\n",
    "    this function is jit-compiled, the value is typically captured when the\n",
    "    function is traced, so re-run this cell after changing `lr`.\n",
    "    \"\"\"\n",
    "    grads = grad(loss)(layers, x, y)\n",
    "    return [\n",
    "        [w - (dw + w * .01) * lr, b - db * lr]\n",
    "        for (w, b), (dw, db) in zip(layers, grads)\n",
    "    ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a29b1fc7-ea08-491e-912a-0b39f071af49",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0 Valid MSE: 4421.083984375\n",
      "Epoch: 10 Valid MSE: 4088.173095703125\n",
      "Epoch: 20 Valid MSE: 3300.42919921875\n",
      "Epoch: 30 Valid MSE: 1643.950439453125\n",
      "Epoch: 40 Valid MSE: 258.70050048828125\n",
      "Epoch: 50 Valid MSE: 40.592655181884766\n",
      "Epoch: 60 Valid MSE: 25.406206130981445\n",
      "Epoch: 70 Valid MSE: 23.1761417388916\n",
      "Epoch: 80 Valid MSE: 22.51946258544922\n",
      "Epoch: 90 Valid MSE: 22.194660186767578\n",
      "Epoch: 100 Valid MSE: 21.970237731933594\n"
     ]
    }
   ],
   "source": [
    "layer_conf = [3,10,10,1]\n",
    "lr = 5e-7\n",
    "batch_size = 32\n",
    "epochs=100\n",
    "seed = 0\n",
    "\n",
    "# Seed NumPy's global generator so the random initial weights are the same on\n",
    "# every Restart & Run All.\n",
    "np.random.seed(seed)\n",
    "layers = init_layers(layer_conf)\n",
    "\n",
    "n_train = train_x.shape[0]\n",
    "\n",
    "# epochs + 1 iterations so that the final epoch (== epochs) is also reported.\n",
    "for epoch in range(epochs+1):\n",
    "\n",
    "    # Visit every training row, including the final (possibly shorter) batch.\n",
    "    # The previous bounds stopped at n_train - batch_size, so the tail of the\n",
    "    # training set was never used. A shorter last batch has a different shape,\n",
    "    # which costs one extra jit compilation of `backward`.\n",
    "    for start in range(0, n_train, batch_size):\n",
    "        stop = start + batch_size\n",
    "        layers = backward(layers, train_x[start:stop], train_y[start:stop])\n",
    "    \n",
    "    if epoch % 10 == 0:\n",
    "        valid_preds = forward(layers, valid_x)\n",
    "        print(f\"Epoch: {epoch} Valid MSE: {mse(valid_y, valid_preds)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "21d53891-1198-4b4b-852d-7717df4b5df9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DeviceArray(24.077864, dtype=float32)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Score the network on the held-out test split; the bare last expression\n",
    "# makes the notebook display the resulting MSE.\n",
    "test_preds = forward(layers, test_x)\n",
    "mse(test_y, test_preds)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
